{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用fastText对文本进行分类\n",
    "\n",
    "清华大学的新闻分类文本数据集下载：[https://thunlp.oss-cn-qingdao.aliyuncs.com/THUCNews.zip](https://thunlp.oss-cn-qingdao.aliyuncs.com/THUCNews.zip)\n",
    "\n",
    "下载后进行解压，把相应的中文目录替换成以下英文名，方便程序读取数据\n",
    "\n",
    "['affairs','constellation','economic','edu','ent','fashion','game','home','house','lottery','science','sports','society','stock']\n",
    "\n",
    "**第一步获取分类文本：输出数据格式： 样本 + 样本标签**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jieba\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "basedir = \"/home/python_home/WeiZhongChuang/ML/Embedding/THUCNews/\" #这是我的文件地址，需跟据文件夹位置进行更改\n",
    "dir_list = ['affairs','constellation','economic','edu','ent','fashion','game','home','house','lottery','science','sports','society','stock']\n",
    "##生成fastext的训练和测试数据集\n",
    "\n",
    "ftrain = open(\"news_fasttext_train.txt\",\"w\")\n",
    "ftest = open(\"news_fasttext_test.txt\",\"w\")\n",
    "\n",
    "num = -1\n",
    "for e in dir_list:\n",
    "    num += 1\n",
    "    indir = basedir + e + '/'\n",
    "    files = os.listdir(indir)\n",
    "    count = 0\n",
    "    for fileName in files:\n",
    "        count += 1            \n",
    "        filepath = indir + fileName\n",
    "        with open(filepath,'r') as fr:\n",
    "            text = fr.read()\n",
    "        text = str(text.encode(\"utf-8\"),'utf-8')\n",
    "        seg_text = jieba.cut(text.replace(\"\\t\",\" \").replace(\"\\n\",\" \"))\n",
    "        outline = \" \".join(seg_text)\n",
    "        outline = outline + \"\\t__label__\" + e + \"\\n\"\n",
    "#         print outline\n",
    "#         break\n",
    "\n",
    "        if count < 10000:\n",
    "            ftrain.write(outline)\n",
    "            ftrain.flush()\n",
    "            continue\n",
    "        elif count  < 20000:\n",
    "            ftest.write(outline)\n",
    "            ftest.flush()\n",
    "            continue\n",
    "        else:\n",
    "            break\n",
    "\n",
    "ftrain.close()\n",
    "ftest.close()\n",
    "print('完成输出数据！')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第二步,使用fastText进行训练模型(如果数据已经准备好，可以直接运行第二步)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练完成！\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)\n",
    "import fasttext\n",
    "#训练模型\n",
    "classifier = fasttext.train_supervised(\"news_fasttext_train.txt\",label_prefix=\"__label__\")\n",
    "\n",
    "#load训练好的模型\n",
    "#classifier = fasttext.load_model('news_fasttext.model.bin', label_prefix='__label__')\n",
    "print('训练完成！')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision： 0.7845266342650989\n"
     ]
    }
   ],
   "source": [
    "result = classifier.test(\"news_fasttext_test.txt\")\n",
    "print('precision：', result[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 由于fasttext貌似只提供全部结果的p值和r值，想要统计不同分类的结果，就需要自己写代码来实现了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['__label__economic', '__label__sports', '__label__fashion', '__label__stock', '__label__edu', '__label__lottery', '__label__home', '__label__house', '__label__game', '__label__constellation', '__label__ent', '__label__science', '__label__affairs', '__label__society']\n",
      "['home', 'sports', 'science', 'stock', 'game', 'economic', 'ent', 'house', 'fashion', 'edu', 'society', 'affairs']\n",
      "\n",
      "预测正确的各个类的数目: {'home': 9390, 'sports': 9273, 'science': 9507, 'stock': 4257, 'game': 9387, 'economic': 8844, 'ent': 8509, 'house': 8588, 'fashion': 1236, 'edu': 8165, 'society': 3467, 'affairs': 8318}\n",
      "\n",
      "测试数据集中各个类的数目: {'home': 10000, 'sports': 10000, 'science': 10000, 'stock': 10000, 'game': 10000, 'economic': 10000, 'ent': 10000, 'house': 10000, 'fashion': 3369, 'edu': 10000, 'society': 10000, 'affairs': 10000}\n",
      "\n",
      "预测结果中各个类的数目: {'__label__economic': 9489, '__label__sports': 13467, '__label__fashion': 1305, '__label__stock': 4896, '__label__edu': 8554, '__label__lottery': 396, '__label__home': 15112, '__label__house': 10030, '__label__game': 10532, '__label__constellation': 734, '__label__ent': 9305, '__label__science': 13474, '__label__affairs': 12543, '__label__society': 3532}\n",
      "\n",
      "home:\t p:0.621361\t r:0.939000\t f:0.747850\n",
      "sports:\t p:0.688572\t r:0.927300\t f:0.790301\n",
      "science:\t p:0.705581\t r:0.950700\t f:0.810003\n",
      "stock:\t p:0.869485\t r:0.425700\t f:0.571563\n",
      "game:\t p:0.891284\t r:0.938700\t f:0.914378\n",
      "economic:\t p:0.932027\t r:0.884400\t f:0.907589\n",
      "ent:\t p:0.914455\t r:0.850900\t f:0.881533\n",
      "house:\t p:0.856231\t r:0.858800\t f:0.857514\n",
      "fashion:\t p:0.947126\t r:0.366874\t f:0.528883\n",
      "edu:\t p:0.954524\t r:0.816500\t f:0.880134\n",
      "society:\t p:0.981597\t r:0.346700\t f:0.512415\n",
      "affairs:\t p:0.663159\t r:0.831800\t f:0.737967\n"
     ]
    }
   ],
   "source": [
    "labels_right = []\n",
    "texts = []\n",
    "with open(\"news_fasttext_test.txt\") as fr:\n",
    "    for line in fr:\n",
    "        line = str(line.encode(\"utf-8\"), 'utf-8').rstrip()\n",
    "        labels_right.append(line.split(\"\\t\")[1].replace(\"__label__\",\"\"))\n",
    "        texts.append(line.split(\"\\t\")[0])\n",
    "    #     print labels\n",
    "    #     print texts\n",
    "#     break\n",
    "labels_predict = [term[0] for term in classifier.predict(texts)[0]] #预测输出结果为二维形式\n",
    "# print labels_predict\n",
    "\n",
    "text_labels = list(set(labels_right))\n",
    "text_predict_labels = list(set(labels_predict))\n",
    "print(text_predict_labels)\n",
    "print(text_labels)\n",
    "print()\n",
    "\n",
    "A = dict.fromkeys(text_labels,0)  #预测正确的各个类的数目\n",
    "B = dict.fromkeys(text_labels,0)   #测试数据集中各个类的数目\n",
    "C = dict.fromkeys(text_predict_labels,0) #预测结果中各个类的数目\n",
    "for i in range(0,len(labels_right)):\n",
    "    B[labels_right[i]] += 1\n",
    "    C[labels_predict[i]] += 1\n",
    "    if labels_right[i] == labels_predict[i].replace('__label__', ''):\n",
    "        A[labels_right[i]] += 1\n",
    "\n",
    "print('预测正确的各个类的数目:', A) \n",
    "print()\n",
    "print('测试数据集中各个类的数目:', B)\n",
    "print()\n",
    "print('预测结果中各个类的数目:', C)\n",
    "print()\n",
    "#计算准确率，召回率，F值\n",
    "for key in B:\n",
    "    try:\n",
    "        r = float(A[key]) / float(B[key])\n",
    "        p = float(A[key]) / float(C['__label__' + key])\n",
    "        f = p * r * 2 / (p + r)\n",
    "        print(\"%s:\\t p:%f\\t r:%f\\t f:%f\" % (key,p,r,f))\n",
    "    except:\n",
    "        print(\"error:\", key, \"right:\", A.get(key,0), \"real:\", B.get(key,0), \"predict:\",C.get(key,0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
